from typing import List, NamedTuple, Optional, Union

import torch
from torch import Tensor

from torch_geometric import EdgeIndex
from torch_geometric.utils import is_torch_sparse_tensor
from torch_geometric.utils.sparse import ptr2index
from torch_geometric.typing import SparseTensor


# Return type of the generated collect function: a named tuple with one typed
# field per entry of `collect_param_dict` (field name and annotation come from
# each entry's `name` and `type_repr`). When `collect_param_dict` is empty the
# class body is rendered as a bare `pass`.
class CollectArgs(NamedTuple):
{%- if collect_param_dict|length > 0 %}
{%- for param in collect_param_dict.values() %}
    {{param.name}}: {{param.type_repr}}
{%- endfor %}
{%- else %}
    pass
{%- endif %}


# Template for the generated collect function. From `edge_index`, the caller's
# arguments and `size`, it derives every value named in `collect_param_dict`
# and returns them packed as `CollectArgs` (in `collect_param_dict` order).
#
# Args:
#   edge_index: a dense index tensor (rows picked via `i` / `j`), a torch
#     sparse tensor in COO or CSR layout, or a `SparseTensor`.
#   (generated): one parameter per entry of `signature.param_dict`.
#   size: list read at positions `i` and `j` (0 and 1); it is also handed to
#     `self._set_size` (presumably filled in there - verify against helper).
#
# Raises:
#   ValueError: torch sparse tensor whose layout is neither COO nor CSR;
#     `edge_index` requested while a sparse matrix was passed; `adj_t`
#     requested while a dense index tensor was passed.
#   NotImplementedError: `edge_index` is neither `Tensor` nor `SparseTensor`.
def {{collect_name}}(
    self,
    edge_index: Union[Tensor, SparseTensor],
{%- for param in signature.param_dict.values() %}
    {{param.name}}: {{param.type_repr}},
{%- endfor %}
    size: List[Optional[int]],
) -> CollectArgs:

    # `i` / `j` pick the row of a dense `edge_index` and the slot of `size`
    # that belong to the `_i` / `_j` side; they swap with `self.flow`.
    i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)

    # Collect special arguments:
    if isinstance(edge_index, Tensor):
        if is_torch_sparse_tensor(edge_index):
            # Sparse torch tensor: it is exposed as `adj_t`, so a template that
            # asks for `edge_index` itself cannot be served from this input.
{%- if 'edge_index' in collect_param_dict %}
            raise ValueError("Cannot collect 'edge_indices' for sparse matrices")
{%- endif %}
            adj_t = edge_index
            # In both layouts the row indices become `edge_index_i` and the
            # column indices `edge_index_j`; `i` / `j` are not consulted here.
            if adj_t.layout == torch.sparse_coo:
                edge_index_i = adj_t.indices()[0]
                edge_index_j = adj_t.indices()[1]
                ptr = None
            elif adj_t.layout == torch.sparse_csr:
                ptr = adj_t.crow_indices()
                edge_index_j = adj_t.col_indices()
                # Presumably expands the compressed row pointer into one row
                # index per stored entry - verify against `ptr2index`.
                edge_index_i = ptr2index(ptr, output_size=edge_index_j.numel())
            else:
                raise ValueError(f"Received invalid layout '{adj_t.layout}'")

            # Fall back to the stored sparse values when the caller passed
            # `None`. The template uses an if/elif chain, so only the first
            # requested name among edge_weight / edge_attr / edge_type gets
            # this treatment. For `edge_attr`, one-dimensional values are not
            # used (it stays `None`).
{%- if 'edge_weight' in collect_param_dict %}
            if edge_weight is None:
                edge_weight = adj_t.values()
{%- elif 'edge_attr' in collect_param_dict %}
            if edge_attr is None:
                _value = adj_t.values()
                edge_attr = None if _value.dim() == 1 else _value
{%- elif 'edge_type' in collect_param_dict %}
            if edge_type is None:
                edge_type = adj_t.values()
{%- endif %}

        else:
            # Dense index tensor: `adj_t` cannot be provided from this input.
{%- if 'adj_t' in collect_param_dict %}
            raise ValueError("Cannot collect 'adj_t' for edge indices")
{%- endif %}
            edge_index_i = edge_index[i]
            edge_index_j = edge_index[j]

            # `ptr` is only available outside of scripting, for an `EdgeIndex`
            # that reports being sorted along the axis selected by `i`.
            ptr = None
            if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):
                if i == 0 and edge_index.is_sorted_by_row:
                  (ptr, _), _ = edge_index.get_csr()
                elif i == 1 and edge_index.is_sorted_by_col:
                  (ptr, _), _ = edge_index.get_csc()

    elif isinstance(edge_index, SparseTensor):
        # `SparseTensor` input: same restriction as for sparse torch tensors.
{%- if 'edge_index' in collect_param_dict %}
        raise ValueError("Cannot collect 'edge_indices' for sparse matrices")
{%- endif %}
        adj_t = edge_index
        # First two results of `coo()` become `edge_index_i` / `edge_index_j`,
        # the third is the (possibly missing) value tensor; `ptr` is the first
        # result of `csr()`.
        edge_index_i, edge_index_j, _value = adj_t.coo()
        ptr, _, _ = adj_t.csr()

        # Same `None` fallback as in the sparse torch tensor branch, using the
        # value returned by `coo()` (which may itself be `None`).
{%- if 'edge_weight' in collect_param_dict %}
        if edge_weight is None:
            edge_weight = _value
{%- elif 'edge_attr' in collect_param_dict %}
        if edge_attr is None:
            edge_attr = None if _value is None or _value.dim() == 1 else _value
{%- elif 'edge_type' in collect_param_dict %}
        if edge_type is None:
            edge_type = _value
{%- endif %}

    else:
        raise NotImplementedError

    # When the first requested name among edge_weight / edge_attr / edge_type
    # has a type annotation ending in `Tensor` (so not, e.g., an `Optional`
    # wrapper), an assertion that it is not `None` is emitted for it.
{%- if 'edge_weight' in collect_param_dict and
    collect_param_dict['edge_weight'].type_repr.endswith('Tensor') %}
    assert edge_weight is not None
{%- elif 'edge_attr' in collect_param_dict and
    collect_param_dict['edge_attr'].type_repr.endswith('Tensor') %}
    assert edge_attr is not None
{%- elif 'edge_type' in collect_param_dict and
    collect_param_dict['edge_type'].type_repr.endswith('Tensor') %}
    assert edge_type is not None
{%- endif %}

    # Collect user-defined arguments:
    # For every requested `<name>_i` / `<name>_j` (other than the edge_index_*
    # and size_* names) the base argument `<name>` may be a pair or a single
    # tensor. A pair is read positionally: element 0 feeds `_j`, element 1
    # feeds `_i`, and `self._set_size` is called with positions 0 and 1. A
    # single tensor calls `self._set_size` with the flow-dependent `i` / `j`.
    # Anything else yields `None`.
    # NOTE(review): the pair branch does not consult `self.flow`, unlike the
    # single-tensor branch - confirm this is intended for 'target_to_source'.
{%- for name in collect_param_dict %}
{%- if (name.endswith('_i') or name.endswith('_j')) and
        name not in ['edge_index_i', 'edge_index_j', 'size_i', 'size_j'] %}
    # ({{loop.index}}) - Collect `{{name}}`:
    if isinstance({{name[:-2]}}, (tuple, list)):
        assert len({{name[:-2]}}) == 2
        _{{name[:-2]}}_0, _{{name[:-2]}}_1 = {{name[:-2]}}[0], {{name[:-2]}}[1]
        if isinstance(_{{name[:-2]}}_0, Tensor):
            self._set_size(size, 0, _{{name[:-2]}}_0)
{%- if name.endswith('_j') %}
            {{name}} = self._index_select(_{{name[:-2]}}_0, edge_index_{{name[-1]}})
        else:
            {{name}} = None
{%- endif %}
        if isinstance(_{{name[:-2]}}_1, Tensor):
            self._set_size(size, 1, _{{name[:-2]}}_1)
{%- if name.endswith('_i') %}
            {{name}} = self._index_select(_{{name[:-2]}}_1, edge_index_{{name[-1]}})
        else:
            {{name}} = None
{%- endif %}
    elif isinstance({{name[:-2]}}, Tensor):
        self._set_size(size, {{name[-1]}}, {{name[:-2]}})
        {{name}} = self._index_select({{name[:-2]}}, edge_index_{{name[-1]}})
    else:
        {{name}} = None
{%- endif %}
{%- endfor %}

    # Collect default arguments:
    # Requested names that are not parameters of this function, are not
    # `_i` / `_j` names and are not among the special names handled elsewhere
    # are bound to their declared default, provided a default exists.
{%- for name, param in collect_param_dict.items() %}
{%- if name not in signature.param_dict and
       not name.endswith('_i') and
       not name.endswith('_j') and
       name not in ['edge_index', 'adj_t', 'size', 'ptr', 'index', 'dim_size'] and
       '_empty' not in param.default.__name__ %}
    {{name}} = {{param.default}}
{%- endif %}
{%- endfor %}

    # `index` aliases the `_i` side indices. Each size falls back to the other
    # slot when its own is still `None`; `dim_size` is the `_i` side size.
    index = edge_index_i
    size_i = size[i] if size[i] is not None else size[j]
    size_j = size[j] if size[j] is not None else size[i]
    dim_size = size_i

    return CollectArgs(
{%- for name in collect_param_dict %}
        {{name}},
{%- endfor %}
    )
